{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.layers.convolutional import Conv2D\n",
    "from keras.layers.pooling import MaxPooling2D, AveragePooling2D\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras import backend as K\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    \"\"\"Round each number in `arr` to `places` decimal places.\n",
    "\n",
    "    Every value is scaled by 10**places, rounded to the nearest integer with\n",
    "    the built-in round(), and scaled back down.\n",
    "\n",
    "    Args:\n",
    "        arr: iterable of numbers.\n",
    "        places: number of decimal places to keep (default 6).\n",
    "\n",
    "    Returns:\n",
    "        A new list with the rounded values, in the same order as `arr`.\n",
    "    \"\"\"\n",
    "    scale = 10**places\n",
    "    rounded = []\n",
    "    for value in arr:\n",
    "        rounded.append(round(value * scale) / scale)\n",
    "    return rounded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pipeline 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "random_seed = 1005\n",
    "data_in_shape = (8, 8, 2)\n",
    "\n",
    "# Layer stack under test: a 3x3 'valid' convolution with ReLU, then 2x2 max pooling.\n",
    "layers = [\n",
    "    Conv2D(\n",
    "        4, (3,3),\n",
    "        activation='relu',\n",
    "        padding='valid',\n",
    "        strides=(1,1),\n",
    "        data_format='channels_last',\n",
    "        use_bias=True\n",
    "    ),\n",
    "    MaxPooling2D(\n",
    "        pool_size=(2,2),\n",
    "        strides=None,\n",
    "        padding='valid',\n",
    "        data_format='channels_last'\n",
    "    )\n",
    "]\n",
    "\n",
    "# Thread the input tensor through every layer but the last, then apply the last\n",
    "# layer to obtain the model output.\n",
    "input_layer = Input(shape=data_in_shape)\n",
    "x = input_layer\n",
    "for layer in layers[:-1]:\n",
    "    x = layer(x)\n",
    "output_layer = layers[-1](x)\n",
    "model = Model(inputs=input_layer, outputs=output_layer)\n",
    "\n",
    "# Seed NumPy's global RNG so the generated input is reproducible.\n",
    "np.random.seed(random_seed)\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "\n",
    "# Overwrite every weight tensor with seeded random values of the same shape.\n",
    "# Tensor k is generated after reseeding with random_seed + k, so tensor 0 uses\n",
    "# the same seed as data_in.\n",
    "weights = []\n",
    "for offset, tensor in enumerate(model.get_weights()):\n",
    "    np.random.seed(random_seed + offset)\n",
    "    weights.append(2 * np.random.random(tensor.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "\n",
    "# Run the single sample as a batch of one and keep the first (only) prediction.\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out = result[0]\n",
    "data_out_shape = data_out.shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(data_out.ravel().tolist())\n",
    "\n",
    "# Record input, weights and expected output as flattened, rounded lists plus shapes.\n",
    "weights_formatted = [\n",
    "    {'data': format_decimal(w.ravel().tolist()), 'shape': w.shape}\n",
    "    for w in weights\n",
    "]\n",
    "DATA['pipeline_05'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': weights_formatted,\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Destination of the fixture; the relative path is resolved against the current working directory.\n",
    "filename = '../../test/data/pipeline/05.json'\n",
    "\n",
    "# exist_ok=True creates any missing parent directories and does nothing when the\n",
    "# directory is already there, avoiding a separate exists() check that could race\n",
    "# with another process creating the same directory.\n",
    "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "\n",
    "# Serialize every recorded pipeline in DATA to the fixture file as JSON.\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"pipeline_05\": {\"input\": {\"data\": [0.318429, -0.858397, -0.059042, 0.68597, -0.649837, -0.575506, -0.564232, -0.922183, 0.440614, -0.111226, -0.319004, 0.744745, 0.189863, -0.126804, 0.616934, -0.196828, -0.463643, -0.28664, 0.297658, 0.692912, -0.648859, -0.305725, 0.358888, 0.494721, -0.949408, 0.686943, 0.186782, 0.274151, -0.81428, 0.475143, 0.172363, 0.474509, 0.169688, -0.181608, -0.280812, -0.756105, -0.715837, -0.666569, -0.813171, 0.706929, 0.651385, 0.462418, -0.124894, 0.842424, 0.144887, -0.127564, -0.425119, 0.479301, -0.136498, 0.717532, 0.774194, -0.922236, 0.811165, -0.118287, -0.850532, 0.393996, -0.901293, 0.700216, -0.54517, -0.892511, -0.756845, 0.534143, -0.276233, 0.357844, 0.0165, -0.297826, -0.68739, -0.786413, -0.770103, -0.525359, 0.576525, 0.743913, 0.070168, 0.465127, -0.305365, 0.038138, 0.949944, -0.256667, -0.745166, -0.43527, -0.191414, -0.640097, -0.381157, -0.373038, 0.08376, -0.31846, 0.255438, 0.30699, 0.378141, 0.152443, -0.232152, -0.43222, 0.959384, 0.970512, -0.209297, 0.963941, -0.289674, 0.849551, -0.135004, 0.621719, 0.857205, -0.75747, 0.993436, -0.509751, 0.519766, -0.687053, -0.250223, -0.105581, -0.447451, 0.434267, -0.312283, -0.526904, 0.94779, -0.335303, -0.940595, -0.021191, 0.996289, 0.453149, -0.055831, 0.13987, 0.707184, -0.61947, 0.244979, -0.933309, 0.474678, 0.285113, -0.248854, -0.966705], \"shape\": [8, 8, 2]}, \"weights\": [{\"data\": [0.318429, -0.858397, -0.059042, 0.68597, -0.649837, -0.575506, -0.564232, -0.922183, 0.440614, -0.111226, -0.319004, 0.744745, 0.189863, -0.126804, 0.616934, -0.196828, -0.463643, -0.28664, 0.297658, 0.692912, -0.648859, -0.305725, 0.358888, 0.494721, -0.949408, 0.686943, 0.186782, 0.274151, -0.81428, 0.475143, 0.172363, 0.474509, 0.169688, -0.181608, -0.280812, -0.756105, -0.715837, -0.666569, -0.813171, 0.706929, 0.651385, 0.462418, -0.124894, 0.842424, 0.144887, -0.127564, -0.425119, 0.479301, -0.136498, 0.717532, 0.774194, -0.922236, 0.811165, -0.118287, 
-0.850532, 0.393996, -0.901293, 0.700216, -0.54517, -0.892511, -0.756845, 0.534143, -0.276233, 0.357844, 0.0165, -0.297826, -0.68739, -0.786413, -0.770103, -0.525359, 0.576525, 0.743913], \"shape\": [3, 3, 2, 4]}, {\"data\": [0.486255, -0.547151, 0.285068, 0.764711], \"shape\": [4]}], \"expected\": {\"data\": [2.841934, 0.200809, 2.326603, 1.624898, 1.502095, 0.17491, 1.884184, 4.888361, 0.271614, 0.0, 1.926391, 2.824068, 3.510188, 0.710325, 1.212564, 2.914058, 0.478862, 1.466825, 1.174033, 2.033663, 1.757399, 0.569198, 1.703989, 0.0, 2.821799, 1.145426, 1.413261, 1.517435, 1.856695, 1.240259, 1.322095, 0.0, 1.829474, 0.281766, 2.021509, 2.110038], \"shape\": [3, 3, 4]}}}\n"
     ]
    }
   ],
   "source": [
    "# Print the JSON serialization of DATA (the same object the previous cell dumps to disk).\n",
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
